区间dp

一.概念

对于一段区间求最优解,且该区间可以分为几个小区间的最优解合并(最优子结构)。

二.基本思路

从小到大枚举区间长度,用小区间合并大区间。

代码实现也很简单:

for( int len = 1 ; len <= n ; len ++ ) //枚举长度
		for( int l = 1 ; l + len - 1 <= n ; l ++ ) { //枚举当前区间左端点
			int r = l + len - 1; //算出右端点
			for( int k = l ; k < r ; k ++ ) { //枚举分割点
				
			}
		}

三.例题

1.P1880 [NOI1995]石子合并

这是一道区间dp\text{dp}的经典入门题板题

因为是一个环状,所以我们先破环为链。

dp[0/1][l][r]dp[0/1][l][r] 表示合并 llrr 这个区间内的石子的最小/最大得分。

那么有:

dp[0][l][r]=min(dp[0][l][k]+dp[0][k+1][r]+a[l...r])   (lk<r)dp[0][l][r]=min(dp[0][l][k]+dp[0][k+1][r]+a[l...r])~~~ ( l \le k < r)

dp[1][l][r]=max(dp[1][l][k]+dp[1][k+1][r]+a[l...r])   (lk<r)dp[1][l][r]=max(dp[1][l][k]+dp[1][k+1][r]+a[l...r])~~~ ( l \le k < r)

最后还原为环,答案为 min(dp[0][i][i+n1]),max(dp[1][i][i+n1])min(dp[0][i][i+n-1]),max(dp[1][i][i+n-1])

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

const int MAXN = 200;
int n , a[ MAXN + 5 ] , sum[ MAXN + 5 ];
int dp[ 2 ][ MAXN + 5 ][ MAXN + 5 ];

int main( ) {
	scanf("%d",&n);
	memset( dp[ 0 ] , 0x3f , sizeof( dp[ 0 ] ) );
    memset( dp[ 1 ] , 0 , sizeof( dp[ 1 ] ) );
	for( int i = 1 ; i <= n ; i ++ ) {
		scanf("%d",&a[ i ]);
		sum[ i ] = sum[ i - 1 ] + a[ i ];
		dp[ 0 ][ i ][ i ] = dp[ 1 ][ i ][ i ] = 0;
	}
    for( int i = n + 1 ; i <= 2 * n ; i ++ ) {
        a[ i ] = a[ i - n ];
        sum[ i ] = sum[ i - 1 ] + a[ i ];
        dp[ 0 ][ i ][ i ] = dp[ 1 ][ i ][ i ] = 0;
    }
	
	for( int len = 1 ; len <= n ; len ++ )
		for( int l = 1 ; l + len - 1 <= 2 * n ; l ++ ) {
			int r = l + len - 1;
			for( int k = l ; k < r ; k ++ ) {
                dp[ 0 ][ l ][ r ] = min( dp[ 0 ][ l ][ r ] , dp[ 0 ][ l ][ k ] + dp[ 0 ][ k + 1 ][ r ] + sum[ r ] - sum[ l - 1 ] );
                dp[ 1 ][ l ][ r ] = max( dp[ 1 ][ l ][ r ] , dp[ 1 ][ l ][ k ] + dp[ 1 ][ k + 1 ][ r ] + sum[ r ] - sum[ l - 1 ] );
            }	
		}
    
    int Min = 0x3f3f3f3f , Max = 0;
    for( int i = 1 ; i <= n ; i ++ ) {
        Min = min( Min , dp[ 0 ][ i ][ n + i - 1 ] );
        Max = max( Max , dp[ 1 ][ i ][ n + i - 1 ] );
    }
	printf("%d\n%d\n", Min , Max );
	return 0;
} 

2.P4342 [IOI1998]Polygon

首先仍然是破环为链。

dp[l][r]dp[l][r] 表示将 llrr 的区间的数进行运算的最大值。

kk 为当前的分割点 , 即需要将 dp[l][k]dp[l][k]dp[k+1][r]dp[k+1][r] 合并。

1.当 op[k]= +op[k]=~'+' 时 , 显然有:

dp[l][r]=max(dp[l][r],dp[l][k]+dp[k+1][r])dp[l][r]=max(dp[l][r],dp[l][k]+dp[k+1][r])

2.当 op[k]= op[k]=~'*' 时 ,

dp[l][r]=max(dp[l][r],dp[l][k]×dp[k+1][r])dp[l][r]=max(dp[l][r],dp[l][k] \times dp[k+1][r])

特殊的 , dp[i][i]=a[i]dp[i][i]=a[i]

但是我们忘了 a[i]a[i] 可以为负数 , 两个负数相乘是可能大于两个正数相乘的。

解决方法也很简单,只需要再记录一个最小值即可。

加法转移一样,乘法转移有一点变化。记 dp[0/1][l][r]dp[0/1][l][r] 为合并 llrr
区间的最小值/最大值。

{dp[0][l][r]=min(dp[0][l][k]×dp[0][k+1][r],dp[0][l][k]×dp[1][k+1][r],dp[1][l][k]×dp[0][k+1][r],dp[1][l][k]×dp[1][k+1][r])dp[1][l][r]=max(dp[0][l][k]×dp[0][k+1][r],dp[0][l][k]×dp[1][k+1][r],dp[1][l][k]×dp[0][k+1][r],dp[1][l][k]×dp[1][k+1][r])\begin{cases} dp[0][l][r]=min( dp[0][l][k] \times dp[0][k+1][r] , dp[0][l][k] \times dp[1][k+1][r] , dp[1][l][k] \times dp[0][k+1][r] , dp[1][l][k] \times dp[1][k+1][r] ) \\ dp[1][l][r]=max( dp[0][l][k] \times dp[0][k+1][r] , dp[0][l][k] \times dp[1][k+1][r] , dp[1][l][k] \times dp[0][k+1][r] , dp[1][l][k] \times dp[1][k+1][r] ) \end{cases}

最后统计最大值即可。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

const int MAXN = 100;
int n , a[ MAXN + 5 ] , dp[ 2 ][ MAXN + 5 ][ MAXN + 5 ];
char op[ MAXN + 5 ];

int Get( int l , int r , int k , int f ) {
    if( f == 0 )
        return min( min( dp[ 0 ][ l ][ k ] * dp[ 0 ][ k + 1 ][ r ] , dp[ 1 ][ l ][ k ] * dp[ 1 ][ k + 1 ][ r ] ) , min( dp[ 0 ][ l ][ k ] * dp[ 1 ][ k + 1 ][ r ] , dp[ 1 ][ l ][ k ] * dp[ 0 ][ k + 1 ][ r ] ) );
    else
        return max( max( dp[ 0 ][ l ][ k ] * dp[ 0 ][ k + 1 ][ r ] , dp[ 1 ][ l ][ k ] * dp[ 1 ][ k + 1 ][ r ] ) , max( dp[ 0 ][ l ][ k ] * dp[ 1 ][ k + 1 ][ r ] , dp[ 1 ][ l ][ k ] * dp[ 0 ][ k + 1 ][ r ] ) );
}
int main( ) {
	scanf("%d\n",&n);
    for( int i = 1 ; i <= n ; i ++ ) {
        scanf("%c %d",&op[ i ],&a[ i ]) , getchar( );
        op[ i + n ] = op[ i ] , a[ i + n ] = a[ i ];
    }
    
    memset( dp[ 0 ] , 0x3f , sizeof( dp[ 0 ] ) );
    memset( dp[ 1 ] , 0xcf , sizeof( dp[ 1 ] ) );
    for( int i = 1 ; i <= 2 * n ; i ++ ) dp[ 0 ][ i ][ i ] = dp[ 1 ][ i ][ i ] = a[ i ];

    for( int len = 1 ; len <= n ; len ++ )
        for( int l = 1 ; l + len - 1 <= 2 * n ; l ++ ) {
            int r = l + len - 1;
            for( int k = l ; k < r ; k ++ ) {
                if( op[ k + 1 ] == 't' ) {
                    dp[ 0 ][ l ][ r ] = min( dp[ 0 ][ l ][ r ] , dp[ 0 ][ l ][ k ] + dp[ 0 ][ k + 1 ][ r ] );
                    dp[ 1 ][ l ][ r ] = max( dp[ 1 ][ l ][ r ] , dp[ 1 ][ l ][ k ] + dp[ 1 ][ k + 1 ][ r ] );
                }
                else {
                    dp[ 0 ][ l ][ r ] = min( dp[ 0 ][ l ][ r ] , Get( l , r , k , 0 ) );
                    dp[ 1 ][ l ][ r ] = max( dp[ 1 ][ l ][ r ] , Get( l , r , k , 1 ) );
                }
            }
        }
    
    int Ans = 0;
    for( int i = 1 ; i <= n ; i ++ )
        Ans = max( Ans , dp[ 1 ][ i ][ n + i - 1 ] );
    printf("%d\n", Ans );
    for( int i = 1 ; i <= n ; i ++ )
        if( dp[ 1 ][ i ][ n + i - 1 ] == Ans ) printf("%d ",i);
	return 0;
} 

3.CF149D Coloring Brackets

为了方便讨论,先处理出每个括号对应的匹配,记为 pip_i

dp[l][r][0/1/2][0/1/2]dp[l][r][0/1/2][0/1/2] 表示现在处理的区间为 llrr 并且是一个匹配序列,左端点红/蓝/不染,右端点红/蓝/不染,那么有三种情况:

1.r=l+1r=l+1

即匹配的括号是相邻的,形如 ()'()',那么只需要将任意端点染一个颜色,只有一种方案。

dp[l][r][0][1]=dp[l][r][0][2]=dp[l][r][1][0]=dp[l][r][2][0]=1dp[l][r][0][1]=dp[l][r][0][2]=dp[l][r][1][0]=dp[l][r][2][0]=1

2.p[l]=r &&p[r]=lp[l]=r ~ \&\& p[r]=l

即合法序列的两端是匹配的,去掉两端后的区间显然也是合法的,所以我们先递归处理 dp[l+1][r1]dp[l+1][r-1]

然后保证 lll+1l+1 的颜色不同 , rrr1r-1 的颜色不同,进行转移即可。

3.p[l]=rp[l] \not= r

即合法序列由两部分构成:llp[l]p[l]p[l]+1p[l]+1rr,同样先递归处理。

然后枚举分成的两部分的端点颜色,保证左区间的右端点和右区间的左端点颜色不同(但可以都不染),就可以转移了。

#include <stack>
#include <cstdio>
#include <cstring>
using namespace std; 
#define Mod 1000000007

const int MAXN = 700;
int n , p[ MAXN + 5 ] , dp[ MAXN + 5 ][ MAXN + 5 ][ 3 ][ 3 ];
char str[ MAXN + 5 ];
stack< int > s;

void dfs( int l , int r ) {
	if( r == l + 1 ) {
		dp[ l ][ r ][ 0 ][ 1 ] = dp[ l ][ r ][ 1 ][ 0 ] = 1;
		dp[ l ][ r ][ 0 ][ 2 ] = dp[ l ][ r ][ 2 ][ 0 ] = 1;
	}
	else if( p[ l ] == r ) {
		dfs( l + 1 , r - 1 );
		for( int l1 = 0 ; l1 <= 2 ; l1 ++ )
			for( int r1 = 0 ; r1 <= 2 ; r1 ++ ) {
				if( l1 != 1 ) dp[ l ][ r ][ 1 ][ 0 ] = ( dp[ l ][ r ][ 1 ][ 0 ] + dp[ l + 1 ][ r - 1 ][ l1 ][ r1 ] ) % Mod;
				if( l1 != 2 ) dp[ l ][ r ][ 2 ][ 0 ] = ( dp[ l ][ r ][ 2 ][ 0 ] + dp[ l + 1 ][ r - 1 ][ l1 ][ r1 ] ) % Mod;
				if( r1 != 1 ) dp[ l ][ r ][ 0 ][ 1 ] = ( dp[ l ][ r ][ 0 ][ 1 ] + dp[ l + 1 ][ r - 1 ][ l1 ][ r1 ] ) % Mod;
				if( r1 != 2 ) dp[ l ][ r ][ 0 ][ 2 ] = ( dp[ l ][ r ][ 0 ][ 2 ] + dp[ l + 1 ][ r - 1 ][ l1 ][ r1 ] ) % Mod;
			}
	}
	else {
		dfs( l , p[ l ] ) , dfs( p[ l ] + 1 , r );
		
		for( int l1 = 0 ; l1 <= 2 ; l1 ++ )
			for( int r1 = 0 ; r1 <= 2 ; r1 ++ )
				for( int l2 = 0 ; l2 <= 2 ; l2 ++ )
					for( int r2 = 0 ; r2 <= 2 ; r2 ++ )
						if( r1 == 0 || r1 != l2 )
							dp[ l ][ r ][ l1 ][ r2 ] = ( dp[ l ][ r ][ l1 ][ r2 ] + 1ll * dp[ l ][ p[ l ] ][ l1 ][ r1 ] * dp[ p[ l ] + 1 ][ r ][ l2 ][ r2 ] ) % Mod;
	}
}

int main( ) {
//	freopen("coloring.in","r",stdin);
//	freopen("coloring.out","w",stdout);
	
	scanf("%s", str + 1 ); n = strlen( str + 1 );
	for( int i = 1 ; i <= n ; i ++ ) {
		if( str[ i ] == '(' ) s.push( i );
		if( str[ i ] == ')' ) {
			int t = s.top(); s.pop();
			p[ i ] = t , p[ t ] = i;
		}
	}
	
	dfs( 1 , n );
	
	int Ans = 0;
	for( int i = 0 ; i <= 2 ; i ++ )
		for( int j = 0 ; j <= 2 ; j ++ )
			Ans = ( Ans + dp[ 1 ][ n ][ i ][ j ] ) % Mod;
			
	printf("%d", Ans );
	return 0;
}

4.P3736 [HAOI2016]字符合并

为了使结果更大,我们肯定会合并至无法合并为止。

dp[l][r][S]dp[l][r][S] 表示 将 llrr 这段区间最终合并为 SS 的最大分数。显然 SS 的位数小于 kk , 考虑状压。

首先 dp[i][i][str[i]]=0dp[i][i][str[i]]=0, 其余赋为极小值避免无效转移。

区间 dp\text{dp} 的套路,枚举分割点 kk , 且 kk 的左边贡献 SS , 右边贡献 1/01/0 ,组合得到 S<<1(S<<1)S<<1(S<<|1) 由于 数字1/01/0 只会由位数为 1+p(k1)1+p(k-1) 的字符串产生,所以 kk 每次减 k1k-1 即可。

然后有:

{dp[l][r][S<<1]=max(dp[l][k][S]+dp[k+1][r][0])dp[l][r][S<<11]=max(dp[l][k][S]+dp[k+1][r][1])\begin{cases} dp[l][r][S<<1]=max(dp[l][k][S]+dp[k+1][r][0]) \\ dp[l][r][S<<1|1]=max(dp[l][k][S]+dp[k+1][r][1]) \end{cases}

最后有一种特殊情况,当区间长度恰好为 kk 时直接合并成一个字符。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define INF 1e15

const int MAXN = 300;
int n , m , k , str[ MAXN + 5 ] , c[ MAXN + 5 ] , w[ MAXN + 5 ];
long long dp[ MAXN + 5 ][ MAXN + 5 ][ MAXN + 5 ];

int main( ) {
	scanf("%d %d",&n,&k); m = 1 << k;
	for( int i = 1 ; i <= n ; i ++ )
		scanf("%d",&str[ i ]);
	for( int i = 0 ; i < m ; i ++ )
		scanf("%d %d",&c[ i ],&w[ i ]);
	
	memset( dp , 0xcf , sizeof( dp ) );
	for( int l = n ; l >= 1 ; l -- )
		for( int r = l ; r <= n ; r ++ ) {
			if( l == r ) {
				dp[ l ][ r ][ str[ l ] ] = 0;
				continue;
			}
			
			int len = ( r - l ) % ( k - 1 );
			if( len == 0 ) len = k - 1;
			
			for( int i = r - 1 ; i >= l ; i -= k - 1 )
				for( int S = 0 ; S < 1 << len ; S ++ ) {
					dp[ l ][ r ][ S << 1 ] = max( dp[ l ][ r ][ S << 1 ] , dp[ l ][ i ][ S ] + dp[ i + 1 ][ r ][ 0 ] );
					dp[ l ][ r ][ S << 1 | 1 ] = max( dp[ l ][ r ][ S << 1 | 1 ] , dp[ l ][ i ][ S ] + dp[ i + 1 ][ r ][ 1 ] );
				}
			
			if( len == k - 1 ) {
				long long tmp[ 2 ]; tmp[ 0 ] = -INF , tmp[ 1 ] = -INF; 
				for( int S = 0 ; S < m ; S ++ )
					tmp[ c[ S ] ] = max( tmp[ c[ S ] ] , dp[ l ][ r ][ S ] + w[ S ] );
				dp[ l ][ r ][ 0 ] = tmp[ 0 ] , dp[ l ][ r ][ 1 ] = tmp[ 1 ]; 
			}
		}
	
//	for( int l = 1 ; l <= n ; l ++ )
//		for( int r = l ; r <= n ; r ++ )
//			for( int S = 0 ; S < m ; S ++ )
//				printf("dp[%d][%d][%d]=%lld\n",l,r,S,dp[l][r][S]);
	
	long long Ans = -INF;
	for( int i = 0 ; i < m ; i ++ )
		Ans = max( Ans , dp[ 1 ][ n ][ i ] );
	printf("%lld", Ans );
	return 0;
} 

5.P4302 [SCOI2003]字符串折叠

dp[l][r]dp[l][r] 表示将 [l,r][l,r] 折叠后的最小长度。

特别的,dp[i][i]=1dp[i][i]=1

那么有两种转移:

1.枚举划分点,将两区间长度相加。

2.将整个区间折叠,同样枚举划分点,判断是否能以该划分点折叠。

时间复杂度 Θ(n3logn)\Theta(n^3 logn)

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

const int MAXN = 100;
int n , dp[ MAXN + 5 ][ MAXN + 5 ];
char str[ MAXN + 5 ];

bool check( int l , int r , int len ) {
	for( int i = l ; i <= l + len - 1 ; i ++ )
		for( int j = i ; j <= r ; j += len )
			if( str[ i ] != str[ j ] ) return 0;
	return 1;
}
int chk( int x ) {
	int d = 0;
	for( ; x ; x /= 10 , d ++ );
	return d;
}

int main( ) {
	scanf("%s", str );
	n = strlen( str );
	
	memset( dp , 0x3f , sizeof( dp ) );
	for( int i = 0 ; i < n ; i ++ ) dp[ i ][ i ] = 1;
	for( int len = 2 ; len <= n ; len ++ )
		for( int l = 0 , r ; l + len - 1 < n ; l ++ ) {
			r = l + len - 1;
			for( int k = l ; k < r ; k ++ )
				dp[ l ][ r ] = min( dp[ l ][ r ] , dp[ l ][ k ] + dp[ k + 1 ][ r ] );
			for( int k = l ; k < r ; k ++ ) {
				int flen = k - l + 1;
				if( len % flen == 0 && check( l , r , flen ) )
					dp[ l ][ r ] = min( dp[ l ][ r ] , chk( len / flen ) + 1 + dp[ l ][ k ] + 1 );
			}
		}
	printf("%d\n", dp[ 0 ][ n - 1 ] );
	return 0;
}